{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNNs tutorial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# We assume that the dynet module is on your Python path.\n",
    "# OUTDATED: we also assume that LD_LIBRARY_PATH includes a pointer to where libcnn_shared.so is.\n",
    "from dynet import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### An LSTM/RNN overview:\n",
    "\n",
    "A (1-layer) RNN can be thought of as a sequence of cells, $h_1,...,h_k$, where the index $i$ of $h_i$ runs over the time dimension.\n",
    "\n",
    "Each cell $h_i$ has an input $x_i$ and an output $r_i$. In addition to $x_i$, cell $h_i$ also receives $r_{i-1}$, the output of the previous cell, as input.\n",
    "\n",
    "In a deep (multi-layer) RNN, we don't have a sequence, but a grid. That is, we have several layers of sequences:\n",
    "\n",
    "* $h_1^3,...,h_k^3$\n",
    "* $h_1^2,...,h_k^2$\n",
    "* $h_1^1,...,h_k^1$\n",
    "\n",
    "Let $r_i^j$ be the output of cell $h_i^j$. Then:\n",
    "\n",
    "The input to $h_i^1$ is $x_i$ and $r_{i-1}^1$.\n",
    "\n",
    "The input to $h_i^2$ is $r_i^1$ and $r_{i-1}^2$,\n",
    "and so on."
   ]
  },
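  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact restatement of the grid described above (a sketch only: the symbol $f^j$ for the cell function of layer $j$ and the convention $r_i^0 = x_i$ are notation assumed here, not part of DyNet's API):\n",
    "\n",
    "$$r_i^j = f^j(r_i^{j-1}, r_{i-1}^j), \\qquad r_i^0 = x_i$$\n",
    "\n",
    "For $j=1$ this gives $r_i^1 = f^1(x_i, r_{i-1}^1)$, and for $j=2$ it gives $r_i^2 = f^2(r_i^1, r_{i-1}^2)$, matching the two cases listed above."
   ]
  },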
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The LSTM (RNN) Interface\n",
    "\n",
    "RNN / LSTM / GRU follow the same interface. We have a \"builder\" which is in charge of defining the parameters for the sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pc = ParameterCollection()\n",
    "NUM_LAYERS=2\n",
    "INPUT_DIM=50\n",
    "HIDDEN_DIM=10\n",
    "builder = LSTMBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)\n",
    "# or:\n",
    "# builder = SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that when we create the builder, it adds the internal RNN parameters to the `ParameterCollection`.\n",
    "We do not need to care about them, but they will be optimized together with the rest of the network's parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "s0 = builder.initial_state()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "x1 = vecInput(INPUT_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "s1=s0.add_input(x1)\n",
    "y1 = s1.output()\n",
    "# here, we add x1 to the RNN, and the output we get from the top is y1 (a HIDDEN_DIM-dim vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10,)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y1.npvalue().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "s2=s1.add_input(x1) # we can add another input\n",
    "y2=s2.output()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If our LSTM/RNN were one layer deep, y2 would be equal to the hidden state. However, since it is 2 layers deep, y2 is only the hidden state (= output) of the last layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we want access to all the hidden states (the outputs of both the first and the last layers), we can use the `.h()` method, which returns a tuple of expressions, one for each layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(exprssion 54/0, exprssion 66/0)\n"
     ]
    }
   ],
   "source": [
    "print s2.h()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interface we have seen so far for the LSTM also holds for the Simple RNN:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "all layers: (exprssion 74/0, exprssion 76/0)\n"
     ]
    }
   ],
   "source": [
    "# create a simple rnn builder\n",
    "rnnbuilder=SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)\n",
    "\n",
    "# start a new sequence\n",
    "rs0 = rnnbuilder.initial_state()\n",
    "\n",
    "# add inputs\n",
    "rs1 = rs0.add_input(x1)\n",
    "ry1 = rs1.output()\n",
    "print \"all layers:\", rs1.h()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(exprssion 28/0, exprssion 38/0, exprssion 32/0, exprssion 42/0)\n"
     ]
    }
   ],
   "source": [
    "print s1.s()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To summarize, when calling `.add_input(x)` on an `RNNState` what happens is that the state creates a new RNN/LSTM column, passing it: \n",
    "1. the state of the current RNN column\n",
    "2. the input `x`\n",
    "\n",
    "The new state is then returned, and we can call its `output()` method to get the output `y`, which is the output at the top of the column. We can access the outputs of all the layers (not only the last one) using the `.h()` method of the state.\n",
    "\n",
    "**`.s()`** The internal state of the RNN may be more involved than just the outputs $h$. This is the case for the LSTM, which keeps an extra \"memory\" cell that is used when calculating $h$ and is also passed to the next column. To access the entire hidden state, we use the `.s()` method.\n",
    "\n",
    "The output of `.s()` differs by the type of RNN being used. For the simple-RNN, it is the same as `.h()`. For the LSTM, it is more involved.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN h: (exprssion 74/0, exprssion 76/0)\n",
      "RNN s: (exprssion 74/0, exprssion 76/0)\n",
      "LSTM h: (exprssion 32/0, exprssion 42/0)\n",
      "LSTM s: (exprssion 28/0, exprssion 38/0, exprssion 32/0, exprssion 42/0)\n"
     ]
    }
   ],
   "source": [
    "rnn_h  = rs1.h()\n",
    "rnn_s  = rs1.s()\n",
    "print \"RNN h:\", rnn_h\n",
    "print \"RNN s:\", rnn_s\n",
    "\n",
    "\n",
    "lstm_h = s1.h()\n",
    "lstm_s = s1.s()\n",
    "print \"LSTM h:\", lstm_h\n",
    "print \"LSTM s:\", lstm_s\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the LSTM has two extra state expressions (one for each hidden layer) before the outputs h."
   ]
  },
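  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed layout suggests that, for the LSTM, `.s()` lists the memory cells of all layers first and the outputs of all layers second. A minimal sketch of splitting such a tuple, assuming that ordering holds (the helper `split_lstm_state` is hypothetical, not a DyNet function):\n",
    "\n",
    "```python\n",
    "def split_lstm_state(s, num_layers):\n",
    "    # Assumes `s` holds the memory cells of all layers first,\n",
    "    # followed by the hidden outputs of all layers.\n",
    "    s = tuple(s)\n",
    "    if len(s) != 2 * num_layers:\n",
    "        raise ValueError('expected 2 * num_layers state expressions')\n",
    "    return s[:num_layers], s[num_layers:]\n",
    "```\n",
    "\n",
    "With the objects defined earlier, `cells, hiddens = split_lstm_state(s1.s(), NUM_LAYERS)` should give two tuples of length `NUM_LAYERS`, with `hiddens` holding the same expressions as `s1.h()`."
   ]
  },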
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extra options in the RNN/LSTM interface\n",
    "\n",
    "**Stack LSTM** RNN states are shaped as a stack: we can remove the top and continue from the previous state.\n",
    "This is done either by remembering the previous state and continuing it with a new `.add_input()`, or by\n",
    "accessing the previous state of a given state using its `.prev()` method.\n",
    "\n",
    "**Initializing a new sequence with a given state** When we call `builder.initial_state()`, we are assuming the state has a default (random or zero) initialization. If we want, we can specify a list of expressions that will serve as the initial state. The expected format is the same as the result of a call to `.final_s()`. TODO: this is not supported yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "s2=s1.add_input(x1)\n",
    "s3=s2.add_input(x1)\n",
    "s4=s3.add_input(x1)\n",
    "\n",
    "# let's continue s3 with a new input.\n",
    "s5=s3.add_input(x1)\n",
    "\n",
    "# we now have two different sequences:\n",
    "# s0,s1,s2,s3,s4\n",
    "# s0,s1,s2,s3,s5\n",
    "# the two sequences share parameters.\n",
    "\n",
    "assert(s5.prev() == s3)\n",
    "assert(s4.prev() == s3)\n",
    "\n",
    "s6=s3.prev().add_input(x1)\n",
    "# we now have an additional sequence:\n",
    "# s0,s1,s2,s6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(exprssion 184/0, exprssion 196/0)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s6.h()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(exprssion 180/0, exprssion 192/0, exprssion 184/0, exprssion 196/0)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s6.s()"
   ]
  },
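  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stack behaviour shown above can be wrapped in a small helper. This is only a sketch: the class `StackRNN` and its method names are hypothetical, and it relies on nothing beyond the `initial_state()` and `add_input()` calls used earlier. It remembers earlier states in a Python list instead of calling `.prev()`.\n",
    "\n",
    "```python\n",
    "class StackRNN(object):\n",
    "    # Hypothetical stack wrapper around an RNN builder (not part of DyNet).\n",
    "\n",
    "    def __init__(self, builder):\n",
    "        # keep every state so that pop() can return to an earlier one\n",
    "        self.states = [builder.initial_state()]\n",
    "\n",
    "    def push(self, x):\n",
    "        # extend the sequence from the current top of the stack\n",
    "        self.states.append(self.states[-1].add_input(x))\n",
    "\n",
    "    def pop(self):\n",
    "        # drop the top state; the previous one becomes the top again\n",
    "        if len(self.states) == 1:\n",
    "            raise IndexError('pop from an empty StackRNN')\n",
    "        return self.states.pop()\n",
    "\n",
    "    def top(self):\n",
    "        # the state on which to call .output(), .h() or .s()\n",
    "        return self.states[-1]\n",
    "```\n",
    "\n",
    "For example, `stack = StackRNN(builder)` followed by two `stack.push(x1)` calls and one `stack.pop()` leaves `stack.top()` at the state reached after the first input, and a later `push` branches off from there, much like `s6` branches off from `s2` above."
   ]
  },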
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Aside: memory efficient transduction\n",
    "The `RNNState` interface is convenient, and allows for incremental input construction.\n",
    "However, sometimes we know the sequence of inputs in advance, and care only about the sequence of\n",
    "output expressions. In this case, we can use the `add_inputs(xs)` method, where `xs` is a list of `Expression`s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[exprssion 200/0, exprssion 206/0, exprssion 212/0] [(exprssion 198/0, exprssion 200/0), (exprssion 203/0, exprssion 206/0), (exprssion 209/0, exprssion 212/0)]\n"
     ]
    }
   ],
   "source": [
    "state = rnnbuilder.initial_state()\n",
    "xs = [x1,x1,x1]\n",
    "states = state.add_inputs(xs)\n",
    "outputs = [s.output() for s in states]\n",
    "hs =      [s.h() for s in states]\n",
    "print outputs, hs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is convenient.\n",
    "\n",
    "What if we do not care about `.s()` and `.h()`, and do not need to access the previous vectors? In such cases\n",
    "we can use the `transduce(xs)` method instead of `add_inputs(xs)`.\n",
    "`transduce` takes in a sequence of `Expression`s, and returns a sequence of `Expression`s.\n",
    "As a consequence of not returning `RNNState`s, `transduce` is much more memory efficient than `add_inputs` or a series of calls to `add_input`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[exprssion 216/0, exprssion 222/0, exprssion 228/0]\n"
     ]
    }
   ],
   "source": [
    "state = rnnbuilder.initial_state()\n",
    "xs = [x1,x1,x1]\n",
    "outputs = state.transduce(xs)\n",
    "print outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Character-level LSTM\n",
    "\n",
    "Now that we know the basics of RNNs, let's build a character-level LSTM language-model.\n",
    "We have a sequence LSTM that, at each step, gets as input a character, and needs to predict the next character."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import defaultdict\n",
    "from itertools import count\n",
    "import sys\n",
    "\n",
    "LAYERS = 2\n",
    "INPUT_DIM = 50 \n",
    "HIDDEN_DIM = 50  \n",
    "\n",
    "characters = list(\"abcdefghijklmnopqrstuvwxyz \")\n",
    "characters.append(\"<EOS>\")\n",
    "\n",
    "int2char = list(characters)\n",
    "char2int = {c:i for i,c in enumerate(characters)}\n",
    "\n",
    "VOCAB_SIZE = len(characters)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pc = ParameterCollection()\n",
    "\n",
    "\n",
    "srnn = SimpleRNNBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)\n",
    "lstm = LSTMBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)\n",
    "\n",
    "params = {}\n",
    "params[\"lookup\"] = pc.add_lookup_parameters((VOCAB_SIZE, INPUT_DIM))\n",
    "params[\"R\"] = pc.add_parameters((VOCAB_SIZE, HIDDEN_DIM))\n",
    "params[\"bias\"] = pc.add_parameters((VOCAB_SIZE))\n",
    "\n",
    "# compute and return the loss of the RNN for one sentence\n",
    "def do_one_sentence(rnn, sentence):\n",
    "    # setup the sentence\n",
    "    renew_cg()\n",
    "    s0 = rnn.initial_state()\n",
    "    \n",
    "    \n",
    "    R = parameter(params[\"R\"])\n",
    "    bias = parameter(params[\"bias\"])\n",
    "    lookup = params[\"lookup\"]\n",
    "    sentence = [\"<EOS>\"] + list(sentence) + [\"<EOS>\"]\n",
    "    sentence = [char2int[c] for c in sentence]\n",
    "    s = s0\n",
    "    loss = []\n",
    "    for char,next_char in zip(sentence,sentence[1:]):\n",
    "        s = s.add_input(lookup[char])\n",
    "        probs = softmax(R*s.output() + bias)\n",
    "        loss.append( -log(pick(probs,next_char)) )\n",
    "    loss = esum(loss)\n",
    "    return loss\n",
    " \n",
    "\n",
    "# generate from model:\n",
    "def generate(rnn):\n",
    "    def sample(probs):\n",
    "        rnd = random.random()\n",
    "        for i,p in enumerate(probs):\n",
    "            rnd -= p\n",
    "            if rnd <= 0: break\n",
    "        return i\n",
    "    \n",
    "    # setup the sentence\n",
    "    renew_cg()\n",
    "    s0 = rnn.initial_state()\n",
    "    \n",
    "    R = parameter(params[\"R\"])\n",
    "    bias = parameter(params[\"bias\"])\n",
    "    lookup = params[\"lookup\"]\n",
    "    \n",
    "    s = s0.add_input(lookup[char2int[\"<EOS>\"]])\n",
    "    out=[]\n",
    "    while True:\n",
    "        probs = softmax(R*s.output() + bias)\n",
    "        probs = probs.vec_value()\n",
    "        next_char = sample(probs)\n",
    "        out.append(int2char[next_char])\n",
    "        if out[-1] == \"<EOS>\": break\n",
    "        s = s.add_input(lookup[next_char])\n",
    "    return \"\".join(out[:-1]) # strip the <EOS>\n",
    "        \n",
    "\n",
    "# train, and generate a sample every 5 iterations\n",
    "def train(rnn, sentence):\n",
    "    trainer = SimpleSGDTrainer(pc)\n",
    "    for i in xrange(200):\n",
    "        loss = do_one_sentence(rnn, sentence)\n",
    "        loss_value = loss.value()\n",
    "        loss.backward()\n",
    "        trainer.update()\n",
    "        if i % 5 == 0: \n",
    "            print loss_value,\n",
    "            print generate(rnn)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that:\n",
    "1. We pass the same rnn-builder to `do_one_sentence` over and over again.\n",
    "We must re-use the same rnn-builder, as this is where the shared parameters are kept.\n",
    "2. We call `renew_cg()` before each sentence, because we want a new graph (a new network) for this sentence.\n",
    "The parameters are shared through the `ParameterCollection` and the shared rnn-builder."
   ]
  },
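  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-character loss in `do_one_sentence`, `-log(pick(probs, next_char))`, is the negative log-probability that the softmax assigns to the correct next character. A plain-Python sketch of the same arithmetic on a list of scores (the function name `neg_log_softmax` is made up for illustration, and subtracting the maximum is a common stability trick rather than something the code above does):\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def neg_log_softmax(scores, target):\n",
    "    # returns -log(softmax(scores)[target]) for a list of floats\n",
    "    m = max(scores)  # subtract the max so that exp() cannot overflow\n",
    "    log_z = m + math.log(sum(math.exp(s - m) for s in scores))\n",
    "    return log_z - scores[target]\n",
    "\n",
    "# with two equal scores each class has probability 1/2, so the loss is log(2)\n",
    "print(neg_log_softmax([0.0, 0.0], 0))\n",
    "```\n",
    "\n",
    "The sentence loss is the sum of this quantity over all (character, next character) pairs, which is what `esum(loss)` computes."
   ]
  },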
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "142.737915039 lvawhaevbxulc yxg esuh vkyb gymj dzcnwgq dcjzzk\n",
      "84.1147460938 woifoa odp jpt gxjofkaattj\n",
      "44.212223053 a q io  uoopr ouxducmwi jfxa j\n",
      "23.4485988617 p tctflr\n",
      "9.73490333557 w\n",
      "3.23773050308 yaqzteu pux oa rntd  bxumu yyvvfalejuyhed over the lazy dog\n",
      "1.06309330463 a quick browe fow jumped over the lazy dog\n",
      "0.671298980713 a quick broyn  ox jumped over the lazy dog\n",
      "0.490513861179 a quick brown fox jumped over the lazy dog\n",
      "0.386095941067 a quick brown fox jumped over the lazy dog\n",
      "0.318082690239 a quick brown fox jumped over the lazy dog\n",
      "0.270276993513 a quick brown fox jumped over the lazy dog\n",
      "0.234851941466 a quick brown foz jumped over the lazy dog\n",
      "0.207555636764 a quick brown fox jumped over the lazy dog\n",
      "0.185884565115 a quick brown fox jumped over the lazy dog\n",
      "0.168265148997 a quiuk brown fox jumped over jhe lazy dog\n",
      "0.153665527701 a quick brown fox jumped over the lazy dog\n",
      "0.141367897391 a quick brown fox jumped over the lazy dog\n",
      "0.130873680115 a quick brown fox jumped over the lazy dog\n",
      "0.121810980141 a quick brown fox jumped over the lazy dog\n",
      "0.113908931613 a quick brown fox jumped over the lazy dog\n",
      "0.106958284974 a quick brown fox jumped over the lazy dog\n",
      "0.100796818733 a quick brown fox jumped over the lazy dog\n",
      "0.0953008085489 a quick brown fox jumped over the lazy dog\n",
      "0.090367347002 a zuick brown for jumped over the lazy dog\n",
      "0.0859087407589 a quick brown fox jumped over the lazy dog\n",
      "0.0818664133549 a quick brown fox jumped over the lazy dog\n",
      "0.0781841799617 a quick brown fox jumped over the lazy dog\n",
      "0.0748091414571 a quick brown fox jumped over the lazy dog\n",
      "0.0717144161463 a quick brown fox jumped over the lazy dog\n",
      "0.0688648074865 a quick brown fox jumped over the lazy dog\n",
      "0.0662328600883 a quick brown fox jumped over the lazy dog\n",
      "0.0637853741646 a quick brown fox jumped over the lazy dog\n",
      "0.0615109689534 a quick brown fox jumped over the lazy dog\n",
      "0.0593910999596 a quick brown fox jumped over the lazy dog\n",
      "0.0574130378664 a quick brown fox jumped over the lazy dog\n",
      "0.0555621087551 a quick brown fox jumped over the lazy dog\n",
      "0.0538215488195 a quick brown fox jumped over the lazy dog\n",
      "0.0521896965802 a quick brown fox jumped over the lazy dog\n",
      "0.0506477579474 a quick brown fox jumped over the lazy dog\n"
     ]
    }
   ],
   "source": [
    "sentence = \"a quick brown fox jumped over the lazy dog\"\n",
    "train(srnn, sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "141.891098022 aoyekppy   mocalmz xk atc jlg oaddk\n",
      "128.925964355 hempeyud  ki\n",
      "121.445785522 qpveti  fyobec  ztmr eioknnueh ehecdvabxmc ydpmdm\n",
      "110.670722961 z buws lmy vvrw\n",
      "93.5055999756 vueoa cprlnkrd o ocazk nb olegiep o fftr t\n",
      "82.1586227417 zj rvsr oej c toz bnarreow  fffj\n",
      "67.430847168  rzfik qoyc ohe hqe  oea uitet ou udjkpme oak kdk oe fbu kcz fox dfoprl too o rxat luurnfowrrtj rbtram to url xlj  okrr ooe otm hcy roab llsg doy ifzw rrbow rbowwb oke jxpee\n",
      "54.9477920532 ba uiy doge she ueeze  oejv\n",
      "43.3301696777  qquc crgibbroej  oxne ove rr\n",
      "34.4687461853  uqckk owrbfo og uouk doge l\n",
      "25.5408306122   reuk lfr own fox juamd ov\n",
      "18.9417610168   qojn doo broww boan jover txe zacy moen crlw numk fox joge overwa trez quqk browx  ox ruor oro fow j uoez kon fror bowe luccmd ogwr foy jodmoed ox\n",
      "13.1646575928  qucy dov\n",
      "9.46595668793 wiuuik brttxl laed over tre lazy dog\n",
      "5.6522898674  rukc irown fox juaped over the lazy dov\n",
      "3.38144731522 a quick brown fox jumver the lazy dog\n",
      "1.80010521412 a bfoin fox jumped ovk  fox luick brown fox jumped over the lazy dog\n",
      "1.30616080761 a quic brownn fox jumped over the lazy dog\n",
      "1.02201879025 a quick brown fox jumped over the lazy dog\n",
      "0.83735615015  qucck brown fox jcmped over the lazy dog\n",
      "0.708056390285 a quickz brown fox jumped over the lazy dog\n",
      "0.612650871277 a quick brown fox jumped over the lazy dog\n",
      "0.539469838142 a quick brown fox jumped over thel lazy dog\n",
      "0.481610894203 va quick brown fox jumped over the lazy dog\n",
      "0.434762001038 a quuck dovtbown fox jumped over the lazy dog\n",
      "0.396079242229 a quick brown fox jumped over the lazy dog\n",
      "0.363606244326 a quick brown fox jumped over the laza dog\n",
      "0.335973978043 a quick brown fox jumped over the lazy dog\n",
      "0.312186658382 a quick brown fox jumped over the lazy dog\n",
      "0.291498303413 a quick brown fox qu\n",
      "0.273335546255 a quick brown fox jumped ove\n",
      "0.257278442383 a quick brown fox jumped over the lazy dog\n",
      "0.242971763015 a quick brown fox jumped over the lazy dog\n",
      "0.230153128505 a quick brown fox jumped over the lazy dog\n",
      "0.218599274755 a quick brown fox jumped over the lazy dog\n",
      "0.208135351539 a quick brown fox jumped over the lazy dog\n",
      "0.198613137007 a quick brown fox jumped over tie lazy dog\n",
      "0.189909905195 a quick brown fox jumped over the lazy dog\n",
      "0.181928783655 a quick brown fox jumped over the lazy dog\n",
      "0.174587100744 a quick brown fox jumped over the lazy dog\n"
     ]
    }
   ],
   "source": [
    "sentence = \"a quick brown fox jumped over the lazy dog\"\n",
    "train(lstm, sentence)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model seems to learn the sentence quite well.\n",
    "\n",
    "Somewhat surprisingly, the Simple-RNN model learns more quickly than the LSTM!\n",
    "\n",
    "How can that be?\n",
    "\n",
    "The answer is that we are cheating a bit. The sentence we are trying to learn\n",
    "contains each of its letter bigrams exactly once. This means that the next character is fully determined\n",
    "by the previous two, so a simple trigram model can memorize it very well.\n",
    "\n",
    "Try it out with more complex sequences."
   ]
  },
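  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bigram claim is easy to check directly. The following is a small sketch (the helper name `has_unique_bigrams` is ours, and it ignores the `<EOS>` markers):\n",
    "\n",
    "```python\n",
    "def has_unique_bigrams(text):\n",
    "    # True if no two-character substring occurs more than once\n",
    "    bigrams = [text[i:i + 2] for i in range(len(text) - 1)]\n",
    "    return len(bigrams) == len(set(bigrams))\n",
    "\n",
    "print(has_unique_bigrams('a quick brown fox jumped over the lazy dog'))\n",
    "print(has_unique_bigrams('these pretzels are making me thirsty'))\n",
    "```\n",
    "\n",
    "The first call prints `True`. The second prints `False`, because, for example, the bigram `re` occurs in both `pretzels` and `are`, so the pretzel sentence cannot be predicted from the previous two characters alone."
   ]
  },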
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "332.651580811 a quick brown fox jumped over the lazy dog\n",
      "133.209350586 a quick brown fox jumped over the lazy doe hu yum xd the\n",
      "65.0720596313 azquick brown fox jumped over ohe iog\n",
      "31.5592880249 a quick brown fox jumpedrovtretpede pretzelz are makink ma tui idmilt\n",
      "13.2322559357 theve prwtumpede mhxtjaypny mreticv\n",
      "1.87829053402 thele pretzelb mre laki loet dre za tuiri mtoina ma qui irwt ere sa taetsdaca qamtuioe ma ick mrolnn mhetsirstyyza qa luijuoethetsepsaaya quirk brmtze ehersjlyaa aumu orkrbtoeqz lrea quijk jrowza quiquihi sakiny mr tui ss thels theqetursy famtzi maethehe iretza lamqzd zretsels area qhirk browna yhetza quirkt rxkwn mox ja isi mq thirsty\n",
      "0.680327475071 these pretzels are makind me thirsty\n",
      "0.176128521562 these pretzels are making me thirsty\n",
      "0.126334354281 these pretzels are making me thirsty\n",
      "0.10075186193 these pretzels are making me thirsty\n",
      "0.0846510156989 these pretzels are making me thirsty\n",
      "0.0734022557735 these pretzels are making me thirsty\n",
      "0.0650328546762 these pretzels are making me thirsty\n",
      "0.0585154108703 these pretzels are making me thirsty\n",
      "0.0532807298005 these pretzels are making me thirsty\n",
      "0.0489665567875 these pretzels are making me thirsty\n",
      "0.0453444086015 these pretzels are making me thirsty\n",
      "0.0422535128891 these pretzels are making me thirsty\n",
      "0.0395833179355 these pretzels are making me thirsty\n",
      "0.0372485220432 these mretzels are making me thirsty\n",
      "0.0351839251816 these pretzels are making me thirsty\n",
      "0.0333509668708 these pretzels are making me thirsty\n",
      "0.0317104011774 these pretzels are making me thirsty\n",
      "0.0302277039737 these pretzels are making me thirsty\n",
      "0.0288887582719 these pretzels are making me thirsty\n",
      "0.0276643745601 these pretzels are making me thirsty\n",
      "0.0265435613692 these pretzels are making me thirsty\n",
      "0.0255212895572 these pretzels are making me thirsty\n",
      "0.0245705824345 these pretzels are making me thirsty\n",
      "0.0236932244152 these pretzels are making me thirsty\n",
      "0.0228785891086 these pretzels are making me thirsty\n",
      "0.0221205893904 these pretzels are making me thirsty\n",
      "0.0214090794325 these pretzels are making me thirsty\n",
      "0.0207556784153 these pretzels are making me thirsty\n",
      "0.0201329570264 these pretzels are making me thirsty\n",
      "0.0195484217256 these pretzels are making me thirsty\n",
      "0.0190003421158 these pretzels are making me thirsty\n",
      "0.0184785164893 these pretzels are making me thirsty\n",
      "0.0179911740124 these pretzels are making me thirsty\n",
      "0.0175334792584 these pretzels are making me thirsty\n"
     ]
    }
   ],
   "source": [
    "train(srnn, \"these pretzels are making me thirsty\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
